import cv2
import torch
import torch.utils.data as data
import torch.nn.functional as F
import numpy as np
import os
import random
import csv
import PIL
import skimage  
import albumentations as A

from PIL import Image
from torchvision import transforms
from tqdm import tqdm
from torch.utils.data import DataLoader

# Semantic-segmentation class ids kept as foreground channels by
# preprocess_segment_semantic; every other id in SEMSEG_CLASS_RANGE is
# folded into background.
# NOTE(review): the trailing list is the same ids minus one — presumably the
# ids before the "label - 1" shift done in __getitem__; confirm.
SEMSEG_CLASSES = [2, 3, 4, 5, 6, 8, 9, 10, 12, 13, 15, 16]  # [1, 2, 3, 4, 5, 7, 8, 9, 11, 12, 14, 15]
SEMSEG_CLASS_RANGE = range(1, 17)

all_tasks = ['class_object', 'class_scene', 'depth_euclidean', 'depth_zbuffer', 'keypoints2d', 'edge_occlusion', 'edge_texture', 'keypoints3d', 'normal', 'principal_curvature', 'reshading', 'rgb', 'segment_unsup2d', 'segment_unsup25d']

# Per-task rescaling tables:
#   new_scale[task]     -- target [min, max] range (default [-1, 1])
#   current_scale[task] -- source [min, max] range, or None if not set
#   no_clip[task]       -- True to skip clipping to current_scale
#   preprocess[task]    -- True if the task needs extra preprocessing
new_scale, current_scale, no_clip, preprocess = {}, {}, {}, {}

for task in all_tasks:
    new_scale[task] = [-1., 1.]  # fresh list per task, not shared
    current_scale[task] = None
    no_clip[task] = False
    preprocess[task] = False

# Task-specific overrides (loss noted for reference):
#   class_object, class_scene      -- cross-entropy
#   depth_euclidean                -- l1
#   principal_curvature            -- l2
#   segment_unsup2d                -- metric loss
current_scale['rgb'] = [0.0, 255.0]

# keypoints2d (l1)
current_scale['keypoints2d'] = [0.0, 0.005 * (2 ** 16)]

# keypoints3d
current_scale['keypoints3d'] = [0.0, 1.0 * (2 ** 16)]

# normal (l1)
current_scale['normal'] = [0.0, 255.0]

# reshading (l1)
current_scale['reshading'] = [0.0, 255.0]

# edge_texture (l1)
current_scale['edge_texture'] = [0.0, 0.08 * (2 ** 16)]

# edge_occlusion (l1); values are not clipped to current_scale
current_scale['edge_occlusion'] = [0.0, 0.00625 * (2 ** 16)]
no_clip['edge_occlusion'] = True

# segment_unsup2d / segment_unsup25d
current_scale['segment_unsup2d'] = [0.0, 255.0]
current_scale['segment_unsup25d'] = [0.0, 255.0]

preprocess['principal_curvature'] = True

def curvature_preprocess(img, new_dims, interp_order=1):
    """Standardize the first two channels of a curvature image.

    Keeps channels 0 and 1 of an (H, W, C) array, subtracts a fixed per-channel
    mean and divides by a fixed per-channel std. ``new_dims`` and
    ``interp_order`` are accepted for signature compatibility but unused.
    """
    channel_mean = [123.572, 120.1]
    channel_std = [31.922, 21.658]
    centered = img[:, :, :2] - channel_mean
    return centered / channel_std

def rescale_image(im, new_scale=(-1., 1.), current_scale=None, no_clip=False):
    """Linearly map pixel values from ``current_scale`` into ``new_scale``.

    Args:
        im: Array-like image. It is copied into a new ``np.float32`` array, so
            the caller's data is not modified.
        new_scale: ``(min, max)`` target range.
        current_scale: ``(min, max)`` source range. If None, the values are
            assumed to already lie in ``[0, 1]`` and are only stretched into
            ``new_scale``; no dtype-based range inference is done.
        no_clip: If False (default), values are clipped to ``current_scale``
            before rescaling. Has no effect when ``current_scale`` is None.

    Returns:
        ``np.float32`` array with the same shape as ``im``.
    """
    im = np.array(im).astype(np.float32)
    if current_scale is not None:
        min_val, max_val = current_scale
        if not no_clip:
            im = np.clip(im, min_val, max_val)
        # Normalize to [0, 1] relative to the source range.
        im = im - min_val
        im /= (max_val - min_val)
    min_val, max_val = new_scale
    im *= (max_val - min_val)
    im += min_val

    return im

# The scipy.ndimage.filters namespace is deprecated; import from scipy.ndimage.
from scipy.ndimage import gaussian_filter


def rescale_image_gaussian_blur(img, new_scale=(-1., 1.), interp_order=1, blur_strength=4, current_scale=None, no_clip=False):
    """Rescale an image's values into ``new_scale``, then Gaussian-blur it.

    Parameters
    ----------
    img : array-like
        Input image, forwarded to ``rescale_image``.
    new_scale : (min, max)
        Target value range.
    interp_order : int
        Unused; kept for signature compatibility.
    blur_strength : float
        ``sigma`` passed to ``scipy.ndimage.gaussian_filter``. A scalar sigma
        applies to every axis, including a trailing channel axis if present.
    current_scale : (min, max) or None
        Source value range forwarded to ``rescale_image``. Values are NOT
        clipped to this range before blurring.
    no_clip : bool
        If False (default), the blurred result is clipped to ``new_scale``.

    Returns
    -------
    np.ndarray
        float32 array with the same shape as ``img``.
    """
    # Clipping (if any) happens after the blur, so skip it during rescaling.
    img = rescale_image(img, new_scale, current_scale=current_scale, no_clip=True)
    blurred = gaussian_filter(img, sigma=blur_strength)
    if not no_clip:
        min_val, max_val = new_scale
        np.clip(blurred, min_val, max_val, out=blurred)
    return blurred

def resize_PIL(img, target_size):
    """Bilinearly resize an image and return it as a PIL image.

    Args:
        img: PIL image (or anything ``np.array`` accepts) shaped (H, W) or
            (H, W, C).
        target_size: ``size`` argument for ``F.interpolate``, i.e. an
            ``(height, width)`` pair (or a single int for a square output).

    Returns:
        ``PIL.Image``. Single-channel inputs are cast to int32 and
        multi-channel inputs to uint8. The float->int cast truncates toward zero.
    """
    img_tensor = torch.from_numpy(np.array(img).astype(np.float32))

    is_single_channel = img_tensor.dim() == 2
    if is_single_channel:
        img_tensor = img_tensor.unsqueeze(-1)

    # F.interpolate expects (N, C, H, W).
    img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0)
    img_tensor = F.interpolate(img_tensor, size=target_size, mode='bilinear', align_corners=False)
    img_tensor = img_tensor.squeeze(0).permute(1, 2, 0)  # back to (H, W, C)

    img_tensor = img_tensor.to(torch.int32 if is_single_channel else torch.uint8)
    # Drop only a size-1 channel axis. A bare squeeze() would also collapse a
    # spatial dimension of size 1 and hand PIL a mis-shaped array.
    if img_tensor.shape[-1] == 1:
        img_tensor = img_tensor.squeeze(-1)

    return Image.fromarray(img_tensor.numpy())


def pad_by_reflect(x, padding=1):
    """Pad the last two dimensions of ``x`` by ``padding`` on each side.

    Despite the name, this is not mirror padding. The first / last ``padding``
    slices along each axis are copied verbatim onto that side: for
    ``padding=1`` this equals edge replication, and for ``padding=2`` the
    left side becomes ``[x0, x1, x0, x1, ...]``. The scheme is kept unchanged
    because the edge detector built on it is presumably calibrated against it.

    Args:
        x: Tensor with at least two dimensions.
        padding: Non-negative number of elements to add on each side.

    Returns:
        Padded tensor. ``x`` itself is returned when ``padding`` is 0.

    Raises:
        ValueError: if ``padding`` is negative.
    """
    if padding < 0:
        raise ValueError(f'padding must be non-negative, got {padding}')
    if padding == 0:
        # x[..., -0:] is the whole tensor, so the concatenation below would
        # duplicate x instead of leaving it unpadded.
        return x
    x = torch.cat((x[..., :padding], x, x[..., -padding:]), dim=-1)
    x = torch.cat((x[..., :padding, :], x, x[..., -padding:, :]), dim=-2)
    return x


class SobelEdgeDetector:
    """Gaussian-smoothed Sobel gradient-magnitude edge detector.

    Args:
        kernel_size: Gaussian kernel side is ``2 * (kernel_size // 2) + 1``,
            so an even value is rounded up to the next odd size.
        sigma: Standard deviation of the Gaussian.
    """

    def __init__(self, kernel_size=5, sigma=1):
        self.kernel_size = kernel_size
        self.sigma = sigma

        # 2-D Gaussian sampled on an integer grid centered at 0. It is scaled by
        # 1 / (2*pi*sigma^2) but not renormalized to sum to exactly 1.
        half = kernel_size // 2
        x, y = np.mgrid[-half:half + 1, -half:half + 1]
        normal = 1 / (2.0 * np.pi * sigma ** 2)
        g = np.exp(-((x ** 2 + y ** 2) / (2.0 * sigma ** 2))) * normal

        # Conv weights shaped (out=1, in=1, kH, kW).
        self.gaussian_kernel = torch.from_numpy(g)[None, None, :, :].float()
        self.Kx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float)[None, None, :, :]
        self.Ky = -torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=torch.float)[None, None, :, :]

    def detect(self, img, normalize=True):
        """Return the Sobel gradient magnitude of ``img``.

        Args:
            img: Float tensor shaped (C, H, W) or (N, C, H, W). Channels are
                summed by the smoothing convolution into a single channel.
            normalize: If True, divide by the maximum magnitude over the whole
                batch. An all-zero result (e.g. a flat image) is returned as
                zeros instead of NaN.

        Returns:
            Tensor shaped (1, H, W) for 3-D input, or (N, 1, H, W) for 4-D input.
        """
        squeeze = img.dim() == 3
        if squeeze:
            img = img[None, ...]

        # Keep the kernels on the input's device so GPU tensors work.
        gaussian = self.gaussian_kernel.to(img.device).repeat(1, img.size(1), 1, 1)
        kx = self.Kx.to(img.device)
        ky = self.Ky.to(img.device)

        img = pad_by_reflect(img, padding=self.kernel_size // 2)
        img = F.conv2d(img, gaussian)

        img = pad_by_reflect(img, padding=1)
        gx = F.conv2d(img, kx)
        gy = F.conv2d(img, ky)

        G = (gx.pow(2) + gy.pow(2)).pow(0.5)
        if normalize:
            peak = G.max()
            # Guard 0/0: a flat input has zero gradient everywhere.
            if peak > 0:
                G = G / peak
        if squeeze:
            G = G[0]

        return G


class TaskonomyDataset(data.Dataset):
    """Multi-task Taskonomy dataset returning RGB plus per-task dense labels.

    Each item is a dict mapping ``'rgb'`` and every non-RGB entry of
    ``img_types`` to a tensor. Files are listed from ``<data_dir>/<img_type>/``
    in sorted order and paired across tasks purely by index.

    Args:
        img_types: Task names to load. ``'rgb'`` must be included (it is always
            read). The depth-masked tasks ('depth_zbuffer', 'edge_occlusion',
            'keypoints3d', 'principal_curvature') also need 'depth_euclidean'.
            The first entry determines ``len(self)``.
        data_dir: Root holding the task folders, the ``*.pth`` statistics files
            and the split CSVs.
        split: One of 'tiny', 'medium', 'fullplus'.
        partition: 'train' or 'test'. CSV rows flagged train or val both map to
            'train'. With 'train', a random crop is applied, which requires
            ``resize_scale`` and ``crop_size``.
        transform: Unused; kept for interface compatibility.
        resize_scale: Side length that RGB and most labels are resized to
            (segment_semantic is not resized).
        crop_size: Minimum side length of the random training crop.
        fliplr: Stored but unused.

    Raises:
        ValueError: for an unknown ``split``.
    """

    def __init__(self, img_types, data_dir='./data', split='tiny', partition='train', transform=None, resize_scale=None, crop_size=None, fliplr=False):

        super(TaskonomyDataset, self).__init__()

        self.partition = partition
        self.resize_scale = resize_scale
        self.crop_size = crop_size
        self.fliplr = fliplr  # stored but not used by this class
        self.class_num = {'class_object': 1000, 'class_scene': 365, 'segment_semantic': 18}

        self.to_tensor = transforms.ToTensor()

        # Per-task depth boundaries indexed as depth_quantiles[task][c]
        # (e.g. one entry: [0.42, 0.62, 0.64, 0.68, 0.7000000000000001, 1.0]).
        self.depth_quantiles = torch.load(os.path.join(data_dir, 'depth_quantiles.pth'))

        # 'params': list of (kernel_size, sigma) per edge_texture channel;
        # 'threshold': clip value used to normalize Sobel magnitudes.
        self.edge_params = torch.load(os.path.join(data_dir, 'edge_params.pth'))
        self.sobel_detectors = [SobelEdgeDetector(kernel_size=k, sigma=s) for k, s in self.edge_params['params']]

        # Per-channel clip values for edge_occlusion.
        self.edge_thresholds = torch.load(os.path.join(data_dir, 'edge_thresholds.pth'))

        def loadSplit(splitFile, full=False):
            """Read a split CSV into {'train': [scenes], 'test': [scenes]}."""
            # Scenes skipped unless full=True.
            no_list = {'brinnon', 'cauthron', 'cochranton', 'donaldson', 'german',
                'castor', 'tokeland', 'andover', 'rogue', 'athens', 'broseley', 'tilghmanton', 'winooski', 'rosser', 'arkansaw', 'bonnie', 'willow', 'timberon', 'bohemia', 'micanopy', 'thrall', 'annona', 'byers', 'anaheim', 'duarte', 'wyldwood'
            }
            # Scenes skipped when full=True.
            new_list = {'ballou', 'tansboro', 'cutlerville', 'macarthur', 'rough', 'darnestown', 'maryhill', 'bowlus', 'tomkins', 'herricks', 'mosquito', 'brinnon', 'gough'}

            dictLabels = {}
            with open(splitFile) as csvfile:
                csvreader = csv.reader(csvfile, delimiter=',')
                next(csvreader, None)  # skip the header row
                for row in csvreader:
                    scene = row[0]
                    if scene == 'woodbine':  # missing from the dataset
                        continue
                    if scene == 'wiconisco':  # missing 80 images for edge_texture
                        continue
                    if scene in new_list and full:
                        continue
                    if scene in no_list and not full:
                        continue
                    is_train, is_val = row[1], row[2]
                    label = 'train' if (is_train == '1' or is_val == '1') else 'test'
                    dictLabels.setdefault(label, []).append(scene)
            return dictLabels

        self.split = split

        # split name -> (CSV path relative to data_dir, `full` flag for loadSplit)
        split_files = {
            'tiny': ('train_val_test_tiny.csv', False),
            'medium': ('splits_taskonomy/train_val_test_medium.csv', False),
            'fullplus': ('splits_taskonomy/train_val_test_fullplus.csv', True),
        }
        if split not in split_files:
            raise ValueError(f'unknown split: {split!r}')
        split_file, full = split_files[split]
        self.data = loadSplit(splitFile=os.path.join(data_dir, split_file), full=full)

        # NOTE(review): scene_list comes from the split CSV but is not used to
        # filter data_list below, so every partition lists the same files — confirm.
        self.scene_list = self.data[partition]
        self.img_types = img_types

        # edge_texture is derived from RGB at load time, so it has no file list.
        self.data_list = {img_type: [] for img_type in img_types}
        for img_type in img_types:
            if img_type == 'edge_texture':
                continue
            image_dir = os.path.join(data_dir, img_type)
            self.data_list[img_type] = [os.path.join(image_dir, name) for name in sorted(os.listdir(image_dir))]

        self.length = len(self.data_list[self.img_types[0]])
        # Min/max placeholders; not updated anywhere in this class.
        self._max = {img_type: -1000000.0 for img_type in self.img_types}
        self._min = {img_type: 100000.0 for img_type in self.img_types}

    def preprocess_depth(self, labels, masks, channels, task):
        """Split a log-depth map into per-channel normalized depth ranges.

        Args:
            labels: numpy array of log-normalized depth.
            masks: bool tensor, True where depth is valid. Invalid pixels are
                set to the channel's lower bound, so they map to 0.
            channels: Channel indices into ``self.depth_quantiles[task]``.
            task: For 'depth_euclidean', channel c spans [q[c], q[c+1]]. For any
                other task it spans [q[c], q[5]].

        Returns:
            (labels, masks), each stacked/expanded along a new leading axis.
        """
        labels = torch.from_numpy(labels).float()

        labels_th = []
        for c in channels:
            assert c < len(self.depth_quantiles[task]) - 1

            # Boundary values for this depth segment.
            t_min = self.depth_quantiles[task][c]
            if task == 'depth_euclidean':
                t_max = self.depth_quantiles[task][c + 1]
            else:
                t_max = self.depth_quantiles[task][5]

            # Threshold, then re-normalize to [0, 1].
            labels_ = torch.where(masks, labels, t_min * torch.ones_like(labels))
            labels_ = torch.clip(labels_, t_min, t_max)
            labels_ = (labels_ - t_min) / (t_max - t_min)
            labels_th.append(labels_)

        labels = torch.stack(labels_th)
        masks = masks.expand_as(labels)

        return labels, masks

    def preprocess_edge_texture(self, imgs, channels):
        """Compute Sobel edge maps from ``imgs`` with the detector for each channel.

        Magnitudes are clipped to ``edge_params['threshold']`` and divided by it.
        The mask is all ones.
        """
        labels = [self.sobel_detectors[c].detect(imgs) for c in channels]
        labels = torch.cat(labels, 0)

        labels = torch.clip(labels, 0, self.edge_params['threshold'])
        labels = labels / self.edge_params['threshold']

        masks = torch.ones_like(labels)

        return labels, masks

    def preprocess_edge_occlusion(self, labels, masks, channels):
        """Zero invalid pixels, then clip and normalize per channel.

        Args:
            labels: Tensor of occlusion-edge strengths.
            masks: numpy bool array of valid pixels.
            channels: Indices into ``self.edge_thresholds``.

        Returns:
            (labels, masks) stacked/expanded along a new leading axis.
        """
        masks = torch.from_numpy(masks)

        labels_th = []
        labels = torch.where(masks.bool(), labels, torch.zeros_like(labels))
        for c in channels:
            assert c < len(self.edge_thresholds)
            t_max = self.edge_thresholds[c]

            # Threshold, then re-normalize to [0, 1].
            labels_ = torch.clip(labels, 0, t_max)
            labels_ = labels_ / t_max
            labels_th.append(labels_)

        labels = torch.stack(labels_th, 0)
        masks = masks.expand_as(labels)

        return labels, masks

    def preprocess_segment_semantic(self, labels, channels, drop_background):
        """Map class ids to a one-hot (C, H, W) float tensor over ``channels``.

        Ids in SEMSEG_CLASS_RANGE that are not in ``channels`` become background
        (0). The ids in ``channels`` are relabeled 1..len(channels) in sorted
        order. With ``drop_background``, the background channel is removed.
        """
        # Treat unsupported classes as background.
        for c in SEMSEG_CLASS_RANGE:
            if c not in channels:
                labels = np.where(labels == c, np.zeros_like(labels), labels)

        # Relabel supported classes to 1..len(channels). Ascending order keeps a
        # new id from colliding with a not-yet-processed source id.
        for i, c in enumerate(sorted(channels)):
            labels = np.where(labels == c, (i + 1) * np.ones_like(labels), labels)

        # One-hot encode to (C, H, W).
        labels = torch.from_numpy(labels).long().squeeze(1)
        labels = F.one_hot(labels, len(channels) + 1).permute(2, 0, 1).float()

        if drop_background:
            labels = labels[1:]

        masks = torch.ones_like(labels)

        return labels, masks

    def __len__(self):
        # The test partition deliberately exposes only a tenth of the samples.
        if self.partition == 'test':
            return self.length // 10
        return self.length

    def _sample_crop(self):
        """Draw a random square crop ``(x, y, size)`` lying fully inside the image."""
        size = random.randint(self.crop_size, self.resize_scale)
        # random.randint is inclusive on both ends, so the largest valid offset
        # is resize_scale - size. A "+ 1" here would let the crop run one pixel
        # past the border and come out one pixel short (non-square).
        x = random.randint(0, self.resize_scale - size)
        y = random.randint(0, self.resize_scale - size)
        return x, y, size

    def __getitem__(self, index):
        """Load sample ``index`` as a dict of tensors keyed by task name."""
        output = {}

        img_path = self.data_list['rgb'][index]
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.resize_scale:
            image = cv2.resize(image, (self.resize_scale, self.resize_scale))

        assert image.dtype == 'uint8', 'error: image.dtype is not uint8.'
        image = self.to_tensor(image)  # uint8 HWC -> float CHW in [0, 1]

        output['rgb'] = image

        # Tasks whose validity mask comes from the euclidean depth map.
        depth_masked = ['depth_zbuffer', 'edge_occlusion', 'keypoints3d', 'principal_curvature']

        for img_type in self.img_types:
            if 'rgb' in img_type:
                continue

            label = None
            if img_type != 'edge_texture':
                gt_path = self.data_list[img_type][index]
                label = cv2.imread(gt_path, cv2.IMREAD_UNCHANGED).astype('float32')

            depth_label = None
            if img_type in depth_masked:
                depth_label_path = self.data_list['depth_euclidean'][index]
                depth_label = cv2.imread(depth_label_path, cv2.IMREAD_UNCHANGED).astype('float32')

            # Resize (segment_semantic keeps its original resolution).
            if self.resize_scale:
                if img_type not in ('edge_texture', 'segment_semantic'):
                    label = cv2.resize(label, (self.resize_scale, self.resize_scale))
                if img_type in depth_masked:
                    depth_label = cv2.resize(depth_label, (self.resize_scale, self.resize_scale))

            if img_type != 'edge_texture':
                assert label.dtype == 'float32', 'error: label.dtype is not float32.'
                # float32 input is not rescaled by ToTensor; only HWC -> CHW.
                label = self.to_tensor(label).squeeze()

            # Per-task decoding / normalization. Depth values >= 64500 are
            # treated as invalid (presumably a missing-depth sentinel).
            if img_type == 'segment_semantic':
                label = label.to(torch.long)
                # Merge raw ids 0 and 1, then shift all ids down by one.
                label[label == 0] = 1
                label = label - 1
                channels = SEMSEG_CLASSES
                label, valid_mask = self.preprocess_segment_semantic(label, channels, drop_background=True)
            elif img_type == 'normal':
                label = label / 255
                label = np.clip(label, 0, 1)
                valid_mask = np.ones_like(label)
            elif img_type == 'depth_zbuffer':
                # log(1 + d) / log(2**16) maps [0, 65535] onto [0, 1].
                label = np.log((1 + label.numpy())) / np.log(2 ** 16)
                valid_mask = (torch.from_numpy(depth_label) < 64500)
                label, valid_mask = self.preprocess_depth(label, valid_mask, range(1), 'depth_zbuffer')
            elif img_type == 'depth_euclidean':
                valid_mask = (label < 64500)
                label = np.log((1 + label.numpy())) / np.log(2 ** 16)
                label, valid_mask = self.preprocess_depth(label, valid_mask, range(5), 'depth_euclidean')
            elif img_type == 'edge_texture':
                # Derived on the fly from the (resized) RGB tensor.
                label, valid_mask = self.preprocess_edge_texture(image, range(5))
            elif img_type == 'edge_occlusion':
                label = label / (2 ** 16)
                valid_mask = (depth_label < 64500)
                label, valid_mask = self.preprocess_edge_occlusion(label, valid_mask, range(5))
            elif img_type == 'keypoints2d':
                label = label / (2 ** 16)
                label = np.clip(label, 0, 0.005) / 0.005
                label = label.unsqueeze(0)
                valid_mask = np.ones_like(label)
            elif img_type == 'keypoints3d':
                label = label / (2 ** 16)
                label = label.unsqueeze(0)
                valid_mask = (depth_label < 64500)
            elif img_type == 'reshading':
                label = label[:1] / 255
                label = np.clip(label, 0, 1)
                valid_mask = np.ones_like(label)
            elif img_type == 'principal_curvature':
                # NOTE(review): cv2 loads BGR, so [1:3] selects G and R — confirm intended.
                label = label[1:3] / 255
                label = np.clip(label, 0, 1)
                valid_mask = (depth_label < 64500)
                valid_mask = torch.from_numpy(valid_mask).expand_as(label)
            else:
                raise ValueError(img_type)

            output[img_type] = label

        if self.partition == 'train':
            # One random square crop shared by all tasks, resized back to
            # resize_scale x resize_scale. NOTE(review): bilinear resizing turns
            # one-hot segment_semantic labels into soft labels.
            x, y, size = self._sample_crop()
            for img_type in self.img_types:
                cropped = output[img_type][:, y:y + size, x:x + size]
                output[img_type] = F.interpolate(cropped.unsqueeze(0), size=(self.resize_scale, self.resize_scale), mode='bilinear', align_corners=False).squeeze(0)

        return output


class TaskonomyDataset_test(data.Dataset):
    """Evaluation variant of TaskonomyDataset.

    Differences: no random crop is applied, ``__getitem__`` also returns the
    per-task valid masks, and the single-scene split ``'muleshoe'`` is supported.
    Files are listed from ``<data_dir>/<img_type>/`` in sorted order and paired
    across tasks by index.

    Args:
        img_types: Task names to load. ``'rgb'`` must be included. The
            depth-masked tasks also need 'depth_euclidean'. The first entry
            determines ``len(self)``.
        data_dir: Root holding the task folders, ``*.pth`` statistics and CSVs.
        split: One of 'muleshoe', 'tiny', 'medium', 'fullplus'.
        partition: 'train' or 'test'.
        transform, crop_size, fliplr: Stored or accepted but unused here.
        resize_scale: Side length that RGB and most labels are resized to.

    Raises:
        ValueError: for an unknown ``split``.
    """

    def __init__(self, img_types, data_dir='./data', split='tiny', partition='train', transform=None, resize_scale=None, crop_size=None, fliplr=False):

        super(TaskonomyDataset_test, self).__init__()

        self.partition = partition
        self.resize_scale = resize_scale
        self.crop_size = crop_size
        self.fliplr = fliplr
        self.class_num = {'class_object': 1000, 'class_scene': 365, 'segment_semantic': 18}

        self.to_tensor = transforms.ToTensor()

        # Per-task depth boundaries indexed as depth_quantiles[task][c]
        # (e.g. one entry: [0.42, 0.62, 0.64, 0.68, 0.7000000000000001, 1.0]).
        self.depth_quantiles = torch.load(os.path.join(data_dir, 'depth_quantiles.pth'))

        # 'params': list of (kernel_size, sigma) per edge_texture channel;
        # 'threshold': clip value used to normalize Sobel magnitudes.
        self.edge_params = torch.load(os.path.join(data_dir, 'edge_params.pth'))
        self.sobel_detectors = [SobelEdgeDetector(kernel_size=k, sigma=s) for k, s in self.edge_params['params']]

        # Per-channel clip values for edge_occlusion.
        self.edge_thresholds = torch.load(os.path.join(data_dir, 'edge_thresholds.pth'))

        def loadSplit(splitFile, full=False):
            """Read a split CSV into {'train': [scenes], 'test': [scenes]}."""
            # Scenes skipped unless full=True.
            no_list = {'brinnon', 'cauthron', 'cochranton', 'donaldson', 'german',
                'castor', 'tokeland', 'andover', 'rogue', 'athens', 'broseley', 'tilghmanton', 'winooski', 'rosser', 'arkansaw', 'bonnie', 'willow', 'timberon', 'bohemia', 'micanopy', 'thrall', 'annona', 'byers', 'anaheim', 'duarte', 'wyldwood'
            }
            # Scenes skipped when full=True.
            new_list = {'ballou', 'tansboro', 'cutlerville', 'macarthur', 'rough', 'darnestown', 'maryhill', 'bowlus', 'tomkins', 'herricks', 'mosquito', 'brinnon', 'gough'}

            dictLabels = {}
            with open(splitFile) as csvfile:
                csvreader = csv.reader(csvfile, delimiter=',')
                next(csvreader, None)  # skip the header row
                for row in csvreader:
                    scene = row[0]
                    if scene == 'woodbine':  # missing from the dataset
                        continue
                    if scene == 'wiconisco':  # missing 80 images for edge_texture
                        continue
                    if scene in new_list and full:
                        continue
                    if scene in no_list and not full:
                        continue
                    is_train, is_val = row[1], row[2]
                    label = 'train' if (is_train == '1' or is_val == '1') else 'test'
                    dictLabels.setdefault(label, []).append(scene)
            return dictLabels

        self.split = split

        # split name -> (CSV path relative to data_dir, `full` flag for loadSplit)
        split_files = {
            'tiny': ('train_val_test_tiny.csv', False),
            'medium': ('splits_taskonomy/train_val_test_medium.csv', False),
            'fullplus': ('splits_taskonomy/train_val_test_fullplus.csv', True),
        }
        if split == 'muleshoe':
            # Single-scene split: the requested partition holds only 'muleshoe'.
            self.data = {partition: ['muleshoe']}
        elif split in split_files:
            split_file, full = split_files[split]
            self.data = loadSplit(splitFile=os.path.join(data_dir, split_file), full=full)
        else:
            raise ValueError(f'unknown split: {split!r}')

        # NOTE(review): scene_list is not used to filter data_list below, so every
        # partition lists the same files — confirm.
        self.scene_list = self.data[partition]
        self.img_types = img_types

        # edge_texture is derived from RGB at load time, so it has no file list.
        self.data_list = {img_type: [] for img_type in img_types}
        for img_type in img_types:
            if img_type == 'edge_texture':
                continue
            image_dir = os.path.join(data_dir, img_type)
            self.data_list[img_type] = [os.path.join(image_dir, name) for name in sorted(os.listdir(image_dir))]

        self.length = len(self.data_list[self.img_types[0]])
        # Min/max placeholders; not updated anywhere in this class.
        self._max = {img_type: -1000000.0 for img_type in self.img_types}
        self._min = {img_type: 100000.0 for img_type in self.img_types}

    def preprocess_depth(self, labels, masks, channels, task):
        """Split a log-depth map into per-channel normalized depth ranges.

        Args:
            labels: numpy array of log-normalized depth.
            masks: bool tensor, True where depth is valid. Invalid pixels are
                set to the channel's lower bound, so they map to 0.
            channels: Channel indices into ``self.depth_quantiles[task]``.
            task: For 'depth_euclidean', channel c spans [q[c], q[c+1]]. For any
                other task it spans [q[c], q[5]].

        Returns:
            (labels, masks), each stacked/expanded along a new leading axis.
        """
        labels = torch.from_numpy(labels).float()

        labels_th = []
        for c in channels:
            assert c < len(self.depth_quantiles[task]) - 1

            t_min = self.depth_quantiles[task][c]
            if task == 'depth_euclidean':
                t_max = self.depth_quantiles[task][c + 1]
            else:
                t_max = self.depth_quantiles[task][5]

            # Threshold, then re-normalize to [0, 1].
            labels_ = torch.where(masks, labels, t_min * torch.ones_like(labels))
            labels_ = torch.clip(labels_, t_min, t_max)
            labels_ = (labels_ - t_min) / (t_max - t_min)
            labels_th.append(labels_)

        labels = torch.stack(labels_th)
        masks = masks.expand_as(labels)

        return labels, masks

    def preprocess_edge_texture(self, imgs, channels):
        """Compute Sobel edge maps from ``imgs`` with the detector for each channel.

        Magnitudes are clipped to ``edge_params['threshold']`` and divided by it.
        The mask is all ones.
        """
        labels = [self.sobel_detectors[c].detect(imgs) for c in channels]
        labels = torch.cat(labels, 0)

        labels = torch.clip(labels, 0, self.edge_params['threshold'])
        labels = labels / self.edge_params['threshold']

        masks = torch.ones_like(labels)

        return labels, masks

    def preprocess_edge_occlusion(self, labels, masks, channels):
        """Zero invalid pixels, then clip and normalize per channel.

        Args:
            labels: Tensor of occlusion-edge strengths.
            masks: numpy bool array of valid pixels.
            channels: Indices into ``self.edge_thresholds``.

        Returns:
            (labels, masks) stacked/expanded along a new leading axis.
        """
        masks = torch.from_numpy(masks)

        labels_th = []
        labels = torch.where(masks.bool(), labels, torch.zeros_like(labels))
        for c in channels:
            assert c < len(self.edge_thresholds)
            t_max = self.edge_thresholds[c]

            labels_ = torch.clip(labels, 0, t_max)
            labels_ = labels_ / t_max
            labels_th.append(labels_)

        labels = torch.stack(labels_th, 0)
        masks = masks.expand_as(labels)

        return labels, masks

    def preprocess_segment_semantic(self, labels, channels, drop_background):
        """Map class ids to a one-hot (C, H, W) float tensor over ``channels``.

        Ids in SEMSEG_CLASS_RANGE that are not in ``channels`` become background
        (0). The ids in ``channels`` are relabeled 1..len(channels) in sorted
        order. With ``drop_background``, the background channel is removed.
        """
        for c in SEMSEG_CLASS_RANGE:
            if c not in channels:
                labels = np.where(labels == c, np.zeros_like(labels), labels)

        # Ascending order keeps a new id from colliding with an unprocessed source id.
        for i, c in enumerate(sorted(channels)):
            labels = np.where(labels == c, (i + 1) * np.ones_like(labels), labels)

        labels = torch.from_numpy(labels).long().squeeze(1)
        labels = F.one_hot(labels, len(channels) + 1).permute(2, 0, 1).float()

        if drop_background:
            labels = labels[1:]

        masks = torch.ones_like(labels)

        return labels, masks

    def __len__(self):
        # The test partition deliberately exposes only a tenth of the samples.
        if self.partition == 'test':
            return self.length // 10
        return self.length

    def __getitem__(self, index):
        """Return ``(output, valid_mask_all)`` for sample ``index``.

        ``output`` maps 'rgb' and each non-RGB task to its label tensor.
        ``valid_mask_all`` maps each non-RGB task to its mask; depending on the
        task this is a torch tensor or a numpy array.
        """
        output = {}
        valid_mask_all = {}

        img_path = self.data_list['rgb'][index]
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.resize_scale:
            image = cv2.resize(image, (self.resize_scale, self.resize_scale))

        assert image.dtype == 'uint8', 'error: image.dtype is not uint8.'
        image = self.to_tensor(image)  # uint8 HWC -> float CHW in [0, 1]

        output['rgb'] = image

        # Tasks whose validity mask comes from the euclidean depth map.
        depth_masked = ['depth_zbuffer', 'edge_occlusion', 'keypoints3d', 'principal_curvature']

        for img_type in self.img_types:
            if 'rgb' in img_type:
                continue

            label = None
            if img_type != 'edge_texture':
                gt_path = self.data_list[img_type][index]
                label = cv2.imread(gt_path, cv2.IMREAD_UNCHANGED).astype('float32')

            depth_label = None
            if img_type in depth_masked:
                depth_label_path = self.data_list['depth_euclidean'][index]
                depth_label = cv2.imread(depth_label_path, cv2.IMREAD_UNCHANGED).astype('float32')

            # Resize (segment_semantic keeps its original resolution).
            if self.resize_scale:
                if img_type not in ('edge_texture', 'segment_semantic'):
                    label = cv2.resize(label, (self.resize_scale, self.resize_scale))
                if img_type in depth_masked:
                    depth_label = cv2.resize(depth_label, (self.resize_scale, self.resize_scale))

            if img_type != 'edge_texture':
                assert label.dtype == 'float32', 'error: label.dtype is not float32.'
                label = self.to_tensor(label).squeeze()

            # Per-task decoding. Depth values >= 64500 are treated as invalid
            # (presumably a missing-depth sentinel).
            if img_type == 'segment_semantic':
                label = label.to(torch.long)
                # Merge raw ids 0 and 1, then shift all ids down by one.
                label[label == 0] = 1
                label = label - 1
                channels = SEMSEG_CLASSES
                label, valid_mask = self.preprocess_segment_semantic(label, channels, drop_background=True)
            elif img_type == 'normal':
                label = label / 255
                label = np.clip(label, 0, 1)
                valid_mask = np.ones_like(label)
            elif img_type == 'depth_zbuffer':
                # log(1 + d) / log(2**16) maps [0, 65535] onto [0, 1].
                label = np.log((1 + label.numpy())) / np.log(2 ** 16)
                valid_mask = (torch.from_numpy(depth_label) < 64500)
                label, valid_mask = self.preprocess_depth(label, valid_mask, range(1), 'depth_zbuffer')
            elif img_type == 'depth_euclidean':
                valid_mask = (label < 64500)
                label = np.log((1 + label.numpy())) / np.log(2 ** 16)
                label, valid_mask = self.preprocess_depth(label, valid_mask, range(5), 'depth_euclidean')
            elif img_type == 'edge_texture':
                label, valid_mask = self.preprocess_edge_texture(image, range(5))
            elif img_type == 'edge_occlusion':
                label = label / (2 ** 16)
                valid_mask = (depth_label < 64500)
                label, valid_mask = self.preprocess_edge_occlusion(label, valid_mask, range(5))
            elif img_type == 'keypoints2d':
                label = label / (2 ** 16)
                label = np.clip(label, 0, 0.005) / 0.005
                label = label.unsqueeze(0)
                valid_mask = np.ones_like(label)
            elif img_type == 'keypoints3d':
                label = label / (2 ** 16)
                label = label.unsqueeze(0)
                valid_mask = (depth_label < 64500)
            elif img_type == 'reshading':
                label = label[:1] / 255
                label = np.clip(label, 0, 1)
                valid_mask = np.ones_like(label)
            elif img_type == 'principal_curvature':
                # NOTE(review): cv2 loads BGR, so [1:3] selects G and R — confirm intended.
                label = label[1:3] / 255
                label = np.clip(label, 0, 1)
                valid_mask = (depth_label < 64500)
                valid_mask = torch.from_numpy(valid_mask).expand_as(label)
            else:
                raise ValueError(img_type)

            output[img_type] = label
            valid_mask_all[img_type] = valid_mask

        return output, valid_mask_all


class FewshotTaskonomy(TaskonomyDataset):
    """TaskonomyDataset restricted to ``shots`` randomly drawn samples.

    Indices come from a fixed seed (20250901), so the subset is reproducible.
    They are drawn with ``randint``, i.e. with replacement, so the subset may
    contain duplicates.
    """

    def __init__(self, shots, *args, **kwargs):
        super(FewshotTaskonomy, self).__init__(*args, **kwargs)

        # A private RandomState yields exactly the indices that
        # np.random.seed(20250901) followed by np.random.randint would, without
        # resetting the global NumPy RNG as a side effect.
        rng = np.random.RandomState(20250901)
        self.choose = rng.randint(self.length, size=shots)

        print(self.choose)

        self.length = shots

    def __getitem__(self, index):
        return super(FewshotTaskonomy, self).__getitem__(self.choose[index])


class PercentageTaskonomy(TaskonomyDataset):
    """TaskonomyDataset restricted to ``int(len * perc)`` randomly drawn samples.

    Indices come from a fixed seed (20250901) and are drawn with ``randint``,
    i.e. with replacement, so the subset may contain duplicates.
    """

    def __init__(self, perc, *args, **kwargs):
        super(PercentageTaskonomy, self).__init__(*args, **kwargs)

        self.perc = perc
        subset_size = int(self.length * self.perc)

        # Same indices as np.random.seed(20250901) + np.random.randint, but
        # without resetting the global NumPy RNG.
        rng = np.random.RandomState(20250901)
        self.choose = rng.randint(self.length, size=subset_size)

        self.length = subset_size

    def __getitem__(self, index):
        return super(PercentageTaskonomy, self).__getitem__(self.choose[index])



import random
if __name__ == '__main__':
    # Smoke test: build the training set, fetch the last and first samples, then
    # iterate the DataLoader once.
    # Only tasks handled by TaskonomyDataset.__getitem__ are listed (others such
    # as 'class_object' raise ValueError there). 'rgb' comes first because the
    # first entry sets the dataset length, and 'depth_euclidean' must be present
    # because the depth-masked tasks read it.
    img_types = ['rgb', 'depth_euclidean', 'depth_zbuffer', 'normal', 'principal_curvature', 'edge_occlusion', 'edge_texture', 'keypoints2d', 'keypoints3d', 'reshading']

    train_set = TaskonomyDataset(img_types, split='fullplus', partition='train', resize_scale=256, crop_size=224, fliplr=True)
    print(len(train_set))
    # Distinct names so the module aliases `A` (albumentations) and `data`
    # (torch.utils.data) are not rebound.
    sample = train_set[len(train_set) - 1]
    sample = train_set[0]

    train_loader = DataLoader(train_set, batch_size=28 * 6, num_workers=48, shuffle=False, pin_memory=False)
    for batch in tqdm(train_loader):
        pass
